#include "MyTest.H"

#include <AMReX_MLEBABecLap.H>
#include <AMReX_ParmParse.H>
#include <AMReX_MultiFabUtil.H>
#include <AMReX_EBMultiFabUtil.H>
#include <AMReX_PlotFileUtil.H>
#include <AMReX_EB2.H>

#include <cmath>

using namespace amrex;

MyTest::MyTest()
{
  readParameters();

  initGrids();

  initializeEB();

  initData();
}

void MyTest::solve()
{
  for (int ilev = 0; ilev <= max_level; ++ilev) {
    const MultiFab &vfrc = factory[ilev]->getVolFrac();
    MultiFab        v(vfrc.boxArray(), vfrc.DistributionMap(), 1, 0, MFInfo(), *factory[ilev]);
    MultiFab::Copy(v, vfrc, 0, 0, 1, 0);
    amrex::EB_set_covered(v, 1.0);
    amrex::Print() << "vfrc min = " << v.min(0) << std::endl;
  }

  std::array<LinOpBCType, AMREX_SPACEDIM> mlmg_lobc;
  std::array<LinOpBCType, AMREX_SPACEDIM> mlmg_hibc;
  for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
    if (geom[0].isPeriodic(idim)) {
      mlmg_lobc[idim] = LinOpBCType::Periodic;
      mlmg_hibc[idim] = LinOpBCType::Periodic;
    } else {
      mlmg_lobc[idim] = LinOpBCType::Dirichlet;
      mlmg_hibc[idim] = LinOpBCType::Dirichlet;
    }
  }

  LPInfo info;
  info.setMaxCoarseningLevel(max_coarsening_level);

  MLEBABecLap mleb(geom, grids, dmap, info, amrex::GetVecOfConstPtrs(factory));
  mleb.setMaxOrder(linop_maxorder);

  mleb.setDomainBC(mlmg_lobc, mlmg_hibc);

  for (int ilev = 0; ilev <= max_level; ++ilev) mleb.setLevelBC(ilev, &phi[ilev]);

  mleb.setScalars(scalars[0], scalars[1]);

  for (int ilev = 0; ilev <= max_level; ++ilev) {
    mleb.setACoeffs(ilev, acoef[ilev]);
    mleb.setBCoeffs(ilev, amrex::GetArrOfConstPtrs(bcoef[ilev]));
  }

  if (eb_is_dirichlet) {
    for (int ilev = 0; ilev <= max_level; ++ilev) mleb.setEBDirichlet(ilev, phi[ilev], bcoef_eb[ilev]);
  }

  MLMG mlmg(mleb);
  mlmg.setMaxIter(max_iter);
  mlmg.setMaxFmgIter(max_fmg_iter);
  mlmg.setBottomMaxIter(max_bottom_iter);
  mlmg.setBottomTolerance(bottom_reltol);
  mlmg.setVerbose(verbose);
  mlmg.setBottomVerbose(bottom_verbose);
  if (use_hypre) mlmg.setBottomSolver(MLMG::BottomSolver::hypre);
  if (use_petsc) mlmg.setBottomSolver(MLMG::BottomSolver::petsc);
  const Real tol_rel = reltol;
  const Real tol_abs = 0.0;
  mlmg.solve(amrex::GetVecOfPtrs(phi), amrex::GetVecOfConstPtrs(rhs), tol_rel, tol_abs);
}

void MyTest::writePlotfile()
{
  Vector<MultiFab> plotmf(max_level + 1);
  for (int ilev = 0; ilev <= max_level; ++ilev) {
    const MultiFab &vfrc = factory[ilev]->getVolFrac();
    plotmf[ilev].define(grids[ilev], dmap[ilev], 2, 0);
    MultiFab::Copy(plotmf[ilev], phi[ilev], 0, 0, 1, 0);
    MultiFab::Copy(plotmf[ilev], vfrc, 0, 1, 1, 0);
  }
  WriteMultiLevelPlotfile(plot_file_name, max_level + 1, amrex::GetVecOfConstPtrs(plotmf), {"phi", "vfrac"}, geom, 0.0, Vector<int>(max_level + 1, 0), Vector<IntVect>(max_level, IntVect{2}));
}

void MyTest::readParameters()
{
  ParmParse pp;
  pp.query("max_level", max_level);
  pp.query("n_cell", n_cell);
  pp.query("max_grid_size", max_grid_size);
  pp.query("is_periodic", is_periodic);
  pp.query("eb_is_dirichlet", eb_is_dirichlet);

  pp.query("plot_file", plot_file_name);

  scalars.resize(2);
  if (is_periodic) {
    scalars[0] = 0.0;
    scalars[1] = 1.0;
  } else {
    scalars[0] = 1.0;
    scalars[1] = 1.0;
  }
  pp.queryarr("scalars", scalars);

  pp.query("verbose", verbose);
  pp.query("bottom_verbose", bottom_verbose);
  pp.query("max_iter", max_iter);
  pp.query("max_fmg_iter", max_fmg_iter);
  pp.query("max_bottom_iter", max_bottom_iter);
  pp.query("bottom_reltol", bottom_reltol);
  pp.query("reltol", reltol);
  pp.query("linop_maxorder", linop_maxorder);
  pp.query("max_coarsening_level", max_coarsening_level);
#ifdef AMREX_USE_HYPRE
  pp.query("use_hypre", use_hypre);
#endif
#ifdef AMREX_USE_PETSC
  pp.query("use_petsc", use_petsc);
#endif
}

void MyTest::initGrids()
{
  int nlevels = max_level + 1;
  geom.resize(nlevels);
  grids.resize(nlevels);

  RealBox                         rb({AMREX_D_DECL(0., 0., 0.)}, {AMREX_D_DECL(1., 1., 1.)});
  std::array<int, AMREX_SPACEDIM> isperiodic{AMREX_D_DECL(is_periodic, is_periodic, is_periodic)};
  Geometry::Setup(&rb, 0, isperiodic.data());
  Box domain0(IntVect{AMREX_D_DECL(0, 0, 0)}, IntVect{AMREX_D_DECL(n_cell - 1, n_cell - 1, n_cell - 1)});
  Box domain = domain0;
  for (int ilev = 0; ilev < nlevels; ++ilev) {
    geom[ilev].define(domain);
    domain.refine(ref_ratio);
  }

  domain = domain0;
  for (int ilev = 0; ilev < nlevels; ++ilev) {
    grids[ilev].define(domain);
    grids[ilev].maxSize(max_grid_size);
    domain.grow(-n_cell / 4); // fine level cover the middle of the coarse domain
    domain.refine(ref_ratio);
  }
}

void MyTest::initData()
{
  int nlevels = max_level + 1;
  dmap.resize(nlevels);
  factory.resize(nlevels);
  phi.resize(nlevels);
  rhs.resize(nlevels);
  acoef.resize(nlevels);
  bcoef.resize(nlevels);
  bcoef_eb.resize(nlevels);

  for (int ilev = 0; ilev < nlevels; ++ilev) {
    dmap[ilev].define(grids[ilev]);
    const EB2::IndexSpace &eb_is    = EB2::IndexSpace::top();
    const EB2::Level      &eb_level = eb_is.getLevel(geom[ilev]);
    factory[ilev].reset(new EBFArrayBoxFactory(eb_level, geom[ilev], grids[ilev], dmap[ilev], {2, 2, 2}, EBSupport::full));

    phi[ilev].define(grids[ilev], dmap[ilev], 1, 1, MFInfo(), *factory[ilev]);
    rhs[ilev].define(grids[ilev], dmap[ilev], 1, 0, MFInfo(), *factory[ilev]);
    acoef[ilev].define(grids[ilev], dmap[ilev], 1, 0, MFInfo(), *factory[ilev]);
    for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) bcoef[ilev][idim].define(amrex::convert(grids[ilev], IntVect::TheDimensionVector(idim)), dmap[ilev], 1, 0, MFInfo(), *factory[ilev]);
    if (eb_is_dirichlet) {
      bcoef_eb[ilev].define(grids[ilev], dmap[ilev], 1, 0, MFInfo(), *factory[ilev]);
      bcoef_eb[ilev].setVal(1.0);
    }

    phi[ilev].setVal(0.0);
    rhs[ilev].setVal(0.0);
    acoef[ilev].setVal(1.0);
    for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) bcoef[ilev][idim].setVal(1.0);

    const auto dx = geom[ilev].CellSizeArray();

    if (is_periodic) {
      const Real pi = 4.0 * std::atan(1.0);

      for (MFIter mfi(rhs[ilev]); mfi.isValid(); ++mfi) {
        const Box          &bx  = mfi.fabbox();
        Array4<Real> const &fab = rhs[ilev].array(mfi);
        amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
          Real rx      = (i + 0.5) * dx[0];
          Real ry      = (j + 0.5) * dx[1];
          fab(i, j, k) = std::sin(rx * 2. * pi + 43.5) * std::sin(ry * 2. * pi + 89.);
        });
      }
    } else if (eb_is_dirichlet) {
      phi[ilev].setVal(10.0);
      phi[ilev].setVal(0.0, 0, 1, 0); // set interior
    } else {
      // Initialize Dirichlet boundary
      for (MFIter mfi(phi[ilev]); mfi.isValid(); ++mfi) {
        const Box          &bx  = mfi.fabbox();
        Array4<Real> const &fab = phi[ilev].array(mfi);
        amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
          Real rx      = (i + 0.5) * dx[0];
          Real ry      = (j + 0.5) * dx[1];
          fab(i, j, k) = std::sqrt(0.5) * (rx + ry);
        });
      }
      phi[ilev].setVal(0.0, 0, 1, 0); // set interior to zero
    }
  }
}
